# Copyright 2020 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Encode pickled bars into MusicVAE latent vectors (adapted from the MusicVAE generation script)."""

# TODO(adarob): Add support for models with conditioning.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#exec(open("music_vae_generate.py").read())
import os
import sys
import time

from magenta import music as mm
import midi_io
import configs
from trained_model import TrainedModel
import numpy as np
import tensorflow.compat.v1 as tf
import pretty_midi
import midi_io_FIXED
import pickle

# TF1-compat flag and logging handles.
# NOTE(review): none of `flags`, `logging` or `FLAGS` is read by the visible
# code below; the script hard-codes its config and file paths instead.
flags = tf.app.flags
logging = tf.logging
FLAGS = flags.FLAGS
# The flag definitions below are disabled: they are wrapped in a bare string
# literal (a no-op expression statement), so none of these flags is ever
# registered. They are kept only as a reference for the original CLI options.
"""
flags.DEFINE_string(
    'run_dir', "./",
    'Path to the directory where the latest checkpoint will be loaded from.')
flags.DEFINE_string(
    'checkpoint_file', None,
    'Path to the checkpoint file. run_dir will take priority over this flag.')
flags.DEFINE_string(
    'output_dir', 'music_vae/generated',
    'The directory where MIDI files will be saved to.')
flags.DEFINE_string(
    'config', "cat-mel_2bar_big",
    'The name of the config to use.')
flags.DEFINE_string(
    'mode', 'sample',
    'Generate mode (either `sample` or `interpolate`).')
flags.DEFINE_string(
    'input_midi_1', "",
    'Path of start MIDI file for interpolation.')
flags.DEFINE_string(
    'input_midi_2', "",
    'Path of end MIDI file for interpolation.')
flags.DEFINE_integer(
    'num_outputs', 5,
    'In `sample` mode, the number of samples to produce. In `interpolate` '
    'mode, the number of steps (including the endpoints).')
flags.DEFINE_integer(
    'max_batch_size', 8,
    'The maximum batch size to use. Decrease if you are seeing an OOM.')
flags.DEFINE_float(
    'temperature', 0.5,
    'The randomness of the decoding process.')
flags.DEFINE_string(
    'log', 'INFO',
    'The threshold for what messages will be logged: '
    'DEBUG, INFO, WARN, ERROR, or FATAL.')
"""

# MusicVAE config name. The checkpoint is expected to live in a directory of
# the same name, so the path is derived from it rather than repeated.
config = "cat-mel_2bar_small"
# NOTE(review): batch_size=2 matches the two-item batches passed to
# model.encode() below (each bar is sent twice) — confirm this is intended.
model = TrainedModel(
    configs.CONFIG_MAP[config], batch_size=2,
    checkpoint_dir_or_path=config + "/model.ckpt")

# Load the pickled bar data and the saved index list. Files are closed
# deterministically via `with`.
# NOTE(review): pickle.load can execute arbitrary code; only load pickles
# produced by a trusted source.
with open("pickles/meas16.pcl", "rb") as meas_file:
    meas = pickle.load(meas_file)
with open("pickles/inds16.pcl", "rb") as inds_file:
    inds = pickle.load(inds_file)

# The graphs of `meas` selected (and ordered) by `inds`.
meas2 = [meas[i] for i in inds]
# Filled by the first encoding loop: one list of latent vectors per graph.
graph_vecs = []

# Encode every bar of every selected graph (`meas2`) into a MusicVAE latent
# mean vector. Results are nested: graph_vecs[g][b] is the vector for bar b
# of selected graph g. A checkpoint is written every 50 graphs and after the
# last one.
for (graph_ind, graph) in enumerate(meas2):
    print(graph_ind)
    graph_vecs.append([])
    # Use the selected graph itself. The previous `meas[graph_ind]` ignored
    # the `inds` selection entirely.
    print(len(graph))
    for bar in graph:
        # Keep only the first two fields of each note entry.
        # NOTE(review): presumably (pitch, timing) pairs — confirm against
        # the format expected by the local midi_io module.
        notes = [(a[0], a[1]) for a in bar]
        input_1 = midi_io.midi_file_to_note_sequence(notes)
        # Round-trip through a temporary MIDI file and re-parse it with
        # midi_io_FIXED. Presumably this normalizes the NoteSequence for
        # the model — confirm.
        mm.sequence_proto_to_midi_file(input_1, "tmpmids/tmp.mid")
        input_1 = midi_io_FIXED.midi_file_to_note_sequence("tmpmids/tmp.mid")

        try:
            # The bar is sent twice to fill the model's batch; only `mu`
            # (the second returned value) is kept.
            _, mu, _ = model.encode([input_1, input_1], False)
            graph_vecs[-1].append(mu)
        except Exception as err:  # best-effort: don't abort the whole run
            print("error", err)
            # NOTE(review): this fallback is a flat 256-vector, while `mu`
            # from a 2-item batch is presumably 2-D — confirm that
            # downstream consumers handle both shapes.
            graph_vecs[-1].append(np.random.normal(size=256))

    # Compare against meas2 (the list being iterated) so that the final
    # partial batch is always saved.
    if graph_ind % 50 == 0 or graph_ind == len(meas2) - 1:
        with open("pickles/analyzedmagents.pcl", "wb") as out_file:
            pickle.dump(graph_vecs, out_file)

# Encode every bar of every graph in `meas` into a MusicVAE latent mean
# vector. Unlike the first loop, results are stored in one flat list in
# (graph, bar) order. A checkpoint is written every 50 graphs and after the
# last one.
graph_vecs = []
for (graph_ind, graph) in enumerate(meas):
    print(graph_ind)
    for bar in graph:
        # Keep only the first two fields of each note entry.
        # NOTE(review): presumably (pitch, timing) pairs — confirm against
        # the format expected by the local midi_io module.
        notes = [(a[0], a[1]) for a in bar]
        input_1 = midi_io.midi_file_to_note_sequence(notes)
        # Round-trip through a temporary MIDI file and re-parse it with
        # midi_io_FIXED. Presumably this normalizes the NoteSequence for
        # the model — confirm.
        mm.sequence_proto_to_midi_file(input_1, "tmpmids/tmp.mid")
        input_1 = midi_io_FIXED.midi_file_to_note_sequence("tmpmids/tmp.mid")

        try:
            # The bar is sent twice to fill the model's batch; only `mu`
            # (the second returned value) is kept.
            _, mu, _ = model.encode([input_1, input_1], False)
            graph_vecs.append(mu)
        except Exception as err:  # best-effort: don't abort the whole run
            print("error", err)
            # NOTE(review): this fallback is a flat 256-vector, while `mu`
            # from a 2-item batch is presumably 2-D — confirm that
            # downstream consumers handle both shapes.
            graph_vecs.append(np.random.normal(size=256))

    if graph_ind % 50 == 0 or graph_ind == len(meas) - 1:
        with open("pickles/analyzedmagents2.pcl", "wb") as out_file:
            pickle.dump(graph_vecs, out_file)

